load("nelson_data.RData")

# Data ----
# Columns are pulled from `kkk` explicitly instead of attach()ing it, so the
# search path is left untouched.
# NOTE(review): `kkk` is presumably created by the load() above -- confirm.
# Roles follow the model formulas below: T = treatment, M = mediator,
# Y = outcome. `T` masks the TRUE shorthand; the name is kept for
# compatibility with existing code.
T <- as.matrix(kkk[["vidcdum"]])
M <- as.matrix(kkk[["idisord"]])
Y <- as.matrix(kkk[["kspeech"]])

# Sensitivity grid of R^2* values (proportion of unexplained variance
# explained by an unobserved confounder). The grid includes 1, so at the
# (1, 1) corner rho^2 = 1 and 1 - rho^2 = 0 in the ACME formula.
r.sq.m <- seq(0, 1, 0.01)
r.sq.y <- seq(0, 1, 0.01)

# Fit the mediator and outcome models once and reuse them below.
eq.m <- lm(M ~ T)
eq.y <- lm(Y ~ T + M)

# Observed R^2 of each model.
r.sq.m.str <- summary(eq.m)$r.squared
r.sq.y.str <- summary(eq.y)$r.squared

# The same grid expressed as R-tilde^2 (proportion of the original variance).
r.tilde.m <- (1 - r.sq.m.str) * r.sq.m
r.tilde.y <- (1 - r.sq.y.str) * r.sq.y

# Residual moments used by the sensitivity formula. e.3.star is the
# reduced-form outcome error: (coefficient on M) * e2 + e3, where the M
# coefficient is the third term of Y ~ T + M.
sigma.2 <- var(eq.m$residuals)
e.3.star <- coef(eq.y)[[3]] * eq.m$residuals + eq.y$residuals
sigma.3.star <- var(e.3.star)
sigma.23.star <- cov(eq.m$residuals, e.3.star)

# Inspect the implied |rho| alongside the R-tilde^2 grids.
cbind(sqrt(r.tilde.m * r.tilde.y), r.tilde.m, r.tilde.y)

# ACME (average causal mediation effect) as a function of the sensitivity
# parameter rho = +sqrt(a * b), i.e. the sgn(lambda2 * lambda3) = 1 case.
#
# a, b: R^2* values for the mediator and outcome equations (proportion of
#       unexplained variance explained by an unobserved confounder). Both may
#       be vectors of equal length, as supplied by outer().
# Reads the globals eq.m, sigma.2, sigma.3.star and sigma.23.star.
# Returns beta.2 * gamma(rho). When a * b == 1 the result is non-finite.
#
# Why the covariance is squared: with e3* = gamma * e2 + e3 and
# corr(e2, e3) = rho,
#   Var(e3*) - Cov(e2, e3*)^2 / Var(e2) = Var(e3) * (1 - rho^2),
# which is how Var(e3) is recovered from the reduced-form moments.
med.eff.pos <- function(a, b) {
    rho.sq <- a * b
    rho <- sqrt(rho.sq)
    beta.2 <- coef(eq.m)[[2]]
    # Var(e3) * (1 - rho^2), identified from the reduced-form residuals.
    resid.var <- sigma.3.star - sigma.23.star^2 / sigma.2
    gamma.rho <- sigma.23.star / sigma.2 -
        (rho / sqrt(sigma.2)) * sqrt(resid.var / (1 - rho.sq))
    beta.2 * gamma.rho
}

# ACME (average causal mediation effect) as a function of the sensitivity
# parameter rho = -sqrt(a * b), i.e. the sgn(lambda2 * lambda3) = -1 case.
#
# a, b: R^2* values for the mediator and outcome equations (proportion of
#       unexplained variance explained by an unobserved confounder). Both may
#       be vectors of equal length, as supplied by outer().
# Reads the globals eq.m, sigma.2, sigma.3.star and sigma.23.star.
# Returns beta.2 * gamma(rho). When a * b == 1 the result is non-finite.
#
# Why the covariance is squared: with e3* = gamma * e2 + e3 and
# corr(e2, e3) = rho,
#   Var(e3*) - Cov(e2, e3*)^2 / Var(e2) = Var(e3) * (1 - rho^2),
# which is how Var(e3) is recovered from the reduced-form moments.
med.eff.neg <- function(a, b) {
    rho.sq <- a * b
    rho <- -sqrt(rho.sq)
    beta.2 <- coef(eq.m)[[2]]
    # Var(e3) * (1 - rho^2), identified from the reduced-form residuals.
    resid.var <- sigma.3.star - sigma.23.star^2 / sigma.2
    gamma.rho <- sigma.23.star / sigma.2 -
        (rho / sqrt(sigma.2)) * sqrt(resid.var / (1 - rho.sq))
    beta.2 * gamma.rho
}

# ACME as a function of rho > 0 (sgn(lambda2 * lambda3) = 1), parameterised
# by R-tilde^2 values.
#
# a, b: R-tilde^2 values for the mediator and outcome equations (proportion
#       of the original variance explained by an unobserved confounder).
#       Both may be vectors of equal length, as supplied by outer().
# rho^2 = a * b / ((1 - R^2_M) * (1 - R^2_Y)), using the observed R^2 values
# r.sq.m.str and r.sq.y.str.
# Also reads eq.m, sigma.2, sigma.3.star and sigma.23.star.
#
# Why the covariance is squared: with e3* = gamma * e2 + e3 and
# corr(e2, e3) = rho,
#   Var(e3*) - Cov(e2, e3*)^2 / Var(e2) = Var(e3) * (1 - rho^2).
med.eff.pos.tilde <- function(a, b) {
    rho.sq <- a * b / ((1 - r.sq.m.str) * (1 - r.sq.y.str))
    rho <- sqrt(rho.sq)
    beta.2 <- coef(eq.m)[[2]]
    # Var(e3) * (1 - rho^2), identified from the reduced-form residuals.
    resid.var <- sigma.3.star - sigma.23.star^2 / sigma.2
    gamma.rho <- sigma.23.star / sigma.2 -
        (rho / sqrt(sigma.2)) * sqrt(resid.var / (1 - rho.sq))
    beta.2 * gamma.rho
}

# ACME as a function of rho < 0 (sgn(lambda2 * lambda3) = -1), parameterised
# by R-tilde^2 values.
#
# a, b: R-tilde^2 values for the mediator and outcome equations (proportion
#       of the original variance explained by an unobserved confounder).
#       Both may be vectors of equal length, as supplied by outer().
# rho^2 = a * b / ((1 - R^2_M) * (1 - R^2_Y)), using the observed R^2 values
# r.sq.m.str and r.sq.y.str.
# Also reads eq.m, sigma.2, sigma.3.star and sigma.23.star.
#
# Why the covariance is squared: with e3* = gamma * e2 + e3 and
# corr(e2, e3) = rho,
#   Var(e3*) - Cov(e2, e3*)^2 / Var(e2) = Var(e3) * (1 - rho^2).
med.eff.neg.tilde <- function(a, b) {
    rho.sq <- a * b / ((1 - r.sq.m.str) * (1 - r.sq.y.str))
    rho <- -sqrt(rho.sq)
    beta.2 <- coef(eq.m)[[2]]
    # Var(e3) * (1 - rho^2), identified from the reduced-form residuals.
    resid.var <- sigma.3.star - sigma.23.star^2 / sigma.2
    gamma.rho <- sigma.23.star / sigma.2 -
        (rho / sqrt(sigma.2)) * sqrt(resid.var / (1 - rho.sq))
    beta.2 * gamma.rho
}

# Evaluate each ACME function on the full grid: entry [i, j] is the function
# applied to the i-th mediator value and the j-th outcome value.
r.p <- outer(X = r.sq.m, Y = r.sq.y, FUN = med.eff.pos)
r.n <- outer(X = r.sq.m, Y = r.sq.y, FUN = med.eff.neg)

# Same surfaces over the R-tilde^2 grids.
r.til.p <- outer(X = r.tilde.m, Y = r.tilde.y, FUN = med.eff.pos.tilde)
r.til.n <- outer(X = r.tilde.m, Y = r.tilde.y, FUN = med.eff.neg.tilde)

# Plot for R2 ----
# Reverse both axes of the negative-sign surface so it can be drawn over the
# [-1, 0] x [-1, 0] grid, mirroring the positive-sign panel.
r.n.tmp <- r.n[rev(seq_len(nrow(r.n))), rev(seq_len(ncol(r.n)))]

r.m <- seq(-1, 0, by = 0.01)
r.y <- seq(-1, 0, by = 0.01)

# Tick labels: increasing for the positive panel, decreasing for the
# mirrored panel.
lab.1 <- seq(0, 1, by = 0.1)
lab.2 <- seq(1, 0, by = -0.1)

pdf("r-sq-1.pdf", width = 7, height = 7, onefile = FALSE, paper = "special")

par(mfrow = c(2, 2))

# First panel: blank spacer (type = "n", no axes).
par(mar = c(5, 5, 5, 2))
plot(c(0, 1), c(0, 1), type = "n", axes = FALSE, xlab = "", ylab = "")

# Second panel: sgn(lambda2 * lambda3) = 1 surface.
# NOTE(review): contour levels are hard-coded; confirm they span r.p.
par(mar = c(0, 0, 5, 2))
contour(x = r.sq.m, y = r.sq.y, z = r.p,
        levels = seq(-2.2, -0.6, by = 0.2),
        ylim = c(0, 1), asp = 1, axes = FALSE)
text(0.3, 0.18, expression(paste("sgn", (lambda[2]*lambda[3])==1)))
axis(side = 2, at = lab.1, labels = lab.1, line = -2.1, cex.axis = 0.8)
axis(side = 1, at = lab.1, labels = lab.1, line = -0.6, cex.axis = 0.8)
mtext(expression(paste(R[M]^2,"*")), side = 1, line = 2, cex = 0.8)
mtext(expression(paste(R[Y]^2,"*")), side = 2, line = 0.5, cex = 0.8)

# Third panel: mirrored sgn(lambda2 * lambda3) = -1 surface, with axes on
# the top and right.
# NOTE(review): contour levels are hard-coded; confirm they span r.n.
par(mar = c(2, 4, 1, 0))
contour(x = r.m, y = r.y, z = r.n.tmp,
        levels = seq(-0.4, 1.4, by = 0.2),
        ylim = c(-1, 0), asp = 1, axes = FALSE)
text(-0.35, -0.17, expression(paste("sgn", (lambda[2]*lambda[3])==-1)))
axis(side = 3, at = seq(-1, 0, by = 0.1), labels = lab.2,
     line = -1.1, cex.axis = 0.8)
axis(side = 4, at = seq(-1, 0, by = 0.1), labels = lab.2,
     line = -0.6, cex.axis = 0.8)
mtext(expression(paste(R[M]^2,"*")), side = 3, line = 1.5, cex = 0.8)
mtext(expression(paste(R[Y]^2,"*")), side = 4, line = 2, cex = 0.8)

title(main = "Proportion of unexplained variance \n explained by an unobserved confounder",
      outer = TRUE, line = -3)

dev.off()

# Plot for R^2 tilde ----
# Reverse both axes of the negative-sign surface so it can be drawn in the
# negative quadrant. The index uses r.til.n's own dimensions rather than
# another matrix's, so it stays correct if the two grids ever differ in size.
r.n.til <- r.til.n[rev(seq_len(nrow(r.til.n))), rev(seq_len(ncol(r.til.n)))]

# Negated, reversed R-tilde^2 grids: increasing coordinates for the mirrored
# panel, matching the reversed rows/columns of r.n.til.
r.m.2 <- -rev(r.tilde.m)
r.y.2 <- -rev(r.tilde.y)

# Tick labels: increasing for the positive panel, decreasing for the
# mirrored panel.
lab.1 <- seq(0, 1, by = 0.1)
lab.2 <- seq(1, 0, by = -0.1)

pdf("r-sq-2.pdf", width = 7, height = 7, onefile = FALSE, paper = "special")

par(mfrow = c(2, 2))

# First panel: blank spacer (type = "n", no axes).
par(mar = c(5, 5, 5, 2))
plot(c(0, 1), c(0, 1), type = "n", axes = FALSE, xlab = "", ylab = "")

# Second panel: sgn(lambda2 * lambda3) = 1 surface over the R-tilde^2 grid.
# NOTE(review): contour levels are hard-coded; confirm they span r.til.p.
par(mar = c(0, 0, 5, 2))
contour(r.tilde.m, r.tilde.y, r.til.p, levels = seq(-2.2, -0.6, 0.2),
        ylim = c(0, 1), asp = 1, axes = FALSE)
text(0.7, 0.6, expression(paste("sgn", (lambda[2]*lambda[3])==1)))
axis(2, at = seq(0, 1, by = 0.1), labels = lab.1, line = -2.3, cex.axis = 0.8)
axis(1, at = seq(0, 1, by = 0.1), labels = lab.1, line = -0.6, cex.axis = 0.8)
mtext(expression(paste(tilde(R)[M]^2)), side = 1, line = 2, cex = 0.8)
mtext(expression(paste(tilde(R)[Y]^2)), side = 2, line = 0, cex = 0.8)

# Third panel: mirrored sgn(lambda2 * lambda3) = -1 surface.
# NOTE(review): contour levels are hard-coded; confirm they span r.til.n.
par(mar = c(2, 4, 1, 0))
contour(r.m.2, r.y.2, r.n.til, levels = seq(-0.4, 1.4, 0.2),
        ylim = c(-1, 0), asp = 1, axes = FALSE)
text(-0.8, -0.6, expression(paste("sgn", (lambda[2]*lambda[3])==-1)))
axis(3, at = seq(-1, 0, by = 0.1), labels = lab.2, line = -0.8, cex.axis = 0.8)
axis(4, at = seq(-1, 0, by = 0.1), labels = lab.2, line = -0.6, cex.axis = 0.8)
mtext(expression(paste(tilde(R)[M]^2)), side = 3, line = 1.5, cex = 0.8)
mtext(expression(paste(tilde(R)[Y]^2)), side = 4, line = 2, cex = 0.8)

title(main = "Proportion of original variance \n explained by an unobserved confounder",
      outer = TRUE, line = -3)

dev.off()

